{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6ab4b8dd",
   "metadata": {},
   "source": [
    "我们将使用GroupLens研究小组收集的MovieLens数据集。这个数据集描述了MovieLens的五星评级和标记活动。该数据集包含来自600多名用户的9000多部电影的约10万个评分。我们将使用该数据集生成两种节点类型，分别保存电影和用户的数据，以及一种连接用户和电影的边类型，表示用户对特定电影的评分关系。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d68b3ae6",
   "metadata": {},
   "source": [
    "首先，我们将数据集下载到任意文件夹（在本例中为当前目录）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3c5ba395",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using existing file ml-latest-small.zip\n",
      "Extracting ./data\\ml-latest-small.zip\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.data import download_url, extract_zip\n",
    "\n",
    "url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'\n",
    "extract_zip(download_url(url, './data'), './data')\n",
    "\n",
    "movies_path = './data/ml-latest-small/movies.csv'\n",
    "ratings_path = './data/ml-latest-small/ratings.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c9584b90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9742\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>movieId</th>\n",
       "      <th>title</th>\n",
       "      <th>genres</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Toy Story (1995)</td>\n",
       "      <td>Adventure|Animation|Children|Comedy|Fantasy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>Jumanji (1995)</td>\n",
       "      <td>Adventure|Children|Fantasy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Grumpier Old Men (1995)</td>\n",
       "      <td>Comedy|Romance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>Waiting to Exhale (1995)</td>\n",
       "      <td>Comedy|Drama|Romance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>Father of the Bride Part II (1995)</td>\n",
       "      <td>Comedy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>Heat (1995)</td>\n",
       "      <td>Action|Crime|Thriller</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>Sabrina (1995)</td>\n",
       "      <td>Comedy|Romance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>Tom and Huck (1995)</td>\n",
       "      <td>Adventure|Children</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>Sudden Death (1995)</td>\n",
       "      <td>Action</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>GoldenEye (1995)</td>\n",
       "      <td>Action|Adventure|Thriller</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   movieId                               title  \\\n",
       "0        1                    Toy Story (1995)   \n",
       "1        2                      Jumanji (1995)   \n",
       "2        3             Grumpier Old Men (1995)   \n",
       "3        4            Waiting to Exhale (1995)   \n",
       "4        5  Father of the Bride Part II (1995)   \n",
       "5        6                         Heat (1995)   \n",
       "6        7                      Sabrina (1995)   \n",
       "7        8                 Tom and Huck (1995)   \n",
       "8        9                 Sudden Death (1995)   \n",
       "9       10                    GoldenEye (1995)   \n",
       "\n",
       "                                        genres  \n",
       "0  Adventure|Animation|Children|Comedy|Fantasy  \n",
       "1                   Adventure|Children|Fantasy  \n",
       "2                               Comedy|Romance  \n",
       "3                         Comedy|Drama|Romance  \n",
       "4                                       Comedy  \n",
       "5                        Action|Crime|Thriller  \n",
       "6                               Comedy|Romance  \n",
       "7                           Adventure|Children  \n",
       "8                                       Action  \n",
       "9                    Action|Adventure|Thriller  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "print(len(pd.read_csv(movies_path)))\n",
    "pd.read_csv(movies_path).head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "979fad72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100836\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>4.0</td>\n",
       "      <td>964982703</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4.0</td>\n",
       "      <td>964981247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>4.0</td>\n",
       "      <td>964982224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>47</td>\n",
       "      <td>5.0</td>\n",
       "      <td>964983815</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>50</td>\n",
       "      <td>5.0</td>\n",
       "      <td>964982931</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>70</td>\n",
       "      <td>3.0</td>\n",
       "      <td>964982400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>101</td>\n",
       "      <td>5.0</td>\n",
       "      <td>964980868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>110</td>\n",
       "      <td>4.0</td>\n",
       "      <td>964982176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>151</td>\n",
       "      <td>5.0</td>\n",
       "      <td>964984041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>157</td>\n",
       "      <td>5.0</td>\n",
       "      <td>964984100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating  timestamp\n",
       "0       1        1     4.0  964982703\n",
       "1       1        3     4.0  964981247\n",
       "2       1        6     4.0  964982224\n",
       "3       1       47     5.0  964983815\n",
       "4       1       50     5.0  964982931\n",
       "5       1       70     3.0  964982400\n",
       "6       1      101     5.0  964980868\n",
       "7       1      110     4.0  964982176\n",
       "8       1      151     5.0  964984041\n",
       "9       1      157     5.0  964984100"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(len(pd.read_csv(ratings_path)))\n",
    "pd.read_csv(ratings_path).head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d68b2b5d",
   "metadata": {},
   "source": [
    "为了用PyG数据格式表示这些数据，我们首先定义了一个方法load_node_csv()，该方法读取*.csv文件并返回形状为[num_nodes，num_features]的节点级特征表示x："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5dddf953",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def load_node_csv(path, index_col, encoders=None, **kwargs): # **kwargs用于在函数定义中接收任意数量的关键字参数，是一个字典\n",
    "    df = pd.read_csv(path, index_col=index_col, **kwargs) # 读取*.csv\n",
    "    mapping = {index: i for i, index in enumerate(df.index.unique())} # 将索引映射成连续值\n",
    "\n",
    "    x = None\n",
    "    if encoders is not None:\n",
    "        xs = [encoder(df[col]) for col, encoder in encoders.items()]\n",
    "        x = torch.cat(xs, dim=-1)\n",
    "\n",
    "    return x, mapping"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddf8d849",
   "metadata": {},
   "source": [
    "这里，load_node_csv()从路径读取*.csv文件，并创建一个字典映射，将其索引列映射到范围｛0，…，num_rows-1｝中的连续值。这是必要的，因为我们希望我们的最终数据表示尽可能紧凑，例如，第一行中的电影表示应该可以通过x[0]访问。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6cb88662",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "class SequenceEncoder:\n",
    "    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):\n",
    "        self.device = device\n",
    "        self.model = SentenceTransformer(model_name, device=device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def __call__(self, df):\n",
    "        x = self.model.encode(df.values, show_progress_bar=True,\n",
    "                              convert_to_tensor=True, device=self.device)\n",
    "        print(x.shape)\n",
    "        return x.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "190431d6",
   "metadata": {},
   "source": [
    "SequenceEncoder类加载一个由model_name给定的预先训练的NLP模型，并使用它将字符串列表编码为形状为[num_strings,embedding_dim]的PyTorch张量。我们可以使用此SequenceEncoder对movies.csv文件的标题进行编码。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c0f5038",
   "metadata": {},
   "source": [
    "以类似的方式，我们可以创建另一个编码器，将电影类型转换为分类标签。为此，我们首先需要找到数据中存在的所有电影类型，创建shape[num_movies，num_genres]的特征表示x，并在类型j存在于电影i中的情况下将1分配给x[i，j]："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d75225d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenresEncoder:\n",
    "    def __init__(self, sep='|'):\n",
    "        self.sep = sep\n",
    "\n",
    "    def __call__(self, df):\n",
    "        genres = set(g for col in df.values for g in col.split(self.sep))\n",
    "        mapping = {genre: i for i, genre in enumerate(genres)}\n",
    "\n",
    "        x = torch.zeros(len(df), len(mapping))\n",
    "        for i, col in enumerate(df.values):\n",
    "            for genre in col.split(self.sep):\n",
    "                x[i, mapping[genre]] = 1\n",
    "                \n",
    "        print(x.shape)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de36df1b",
   "metadata": {},
   "source": [
    "有了这个，我们可以通过以下方式获得我们对电影的最终呈现："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f5f04257",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3acfef11c6ef4ca0a33ca24398d271d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/305 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([9742, 384])\n",
      "torch.Size([9742, 20])\n",
      "tensor([[-0.0828,  0.0530,  0.0536,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        [-0.1053,  0.1508, -0.0264,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        [-0.0988,  0.0176, -0.0527,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        ...,\n",
      "        [-0.1115,  0.0310, -0.0177,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        [ 0.0366,  0.0137,  0.0315,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        [-0.0500, -0.0141, -0.0031,  ...,  0.0000,  0.0000,  0.0000]])\n"
     ]
    }
   ],
   "source": [
    "movie_x, movie_mapping = load_node_csv(\n",
    "    movies_path, index_col='movieId', encoders={\n",
    "        'title': SequenceEncoder(),\n",
    "        'genres': GenresEncoder()\n",
    "    })\n",
    "\n",
    "print(movie_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4bbf14f",
   "metadata": {},
   "source": [
    "类似地，我们也可以使用load_node_csv()来获得从userId到连续值的用户映射。但是，此数据集中没有用户的其他特征信息。因此，我们没有定义任何编码器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9ef92960",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 26: 25, 27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 32: 31, 33: 32, 34: 33, 35: 34, 36: 35, 37: 36, 38: 37, 39: 38, 40: 39, 41: 40, 42: 41, 43: 42, 44: 43, 45: 44, 46: 45, 47: 46, 48: 47, 49: 48, 50: 49, 51: 50, 52: 51, 53: 52, 54: 53, 55: 54, 56: 55, 57: 56, 58: 57, 59: 58, 60: 59, 61: 60, 62: 61, 63: 62, 64: 63, 65: 64, 66: 65, 67: 66, 68: 67, 69: 68, 70: 69, 71: 70, 72: 71, 73: 72, 74: 73, 75: 74, 76: 75, 77: 76, 78: 77, 79: 78, 80: 79, 81: 80, 82: 81, 83: 82, 84: 83, 85: 84, 86: 85, 87: 86, 88: 87, 89: 88, 90: 89, 91: 90, 92: 91, 93: 92, 94: 93, 95: 94, 96: 95, 97: 96, 98: 97, 99: 98, 100: 99, 101: 100, 102: 101, 103: 102, 104: 103, 105: 104, 106: 105, 107: 106, 108: 107, 109: 108, 110: 109, 111: 110, 112: 111, 113: 112, 114: 113, 115: 114, 116: 115, 117: 116, 118: 117, 119: 118, 120: 119, 121: 120, 122: 121, 123: 122, 124: 123, 125: 124, 126: 125, 127: 126, 128: 127, 129: 128, 130: 129, 131: 130, 132: 131, 133: 132, 134: 133, 135: 134, 136: 135, 137: 136, 138: 137, 139: 138, 140: 139, 141: 140, 142: 141, 143: 142, 144: 143, 145: 144, 146: 145, 147: 146, 148: 147, 149: 148, 150: 149, 151: 150, 152: 151, 153: 152, 154: 153, 155: 154, 156: 155, 157: 156, 158: 157, 159: 158, 160: 159, 161: 160, 162: 161, 163: 162, 164: 163, 165: 164, 166: 165, 167: 166, 168: 167, 169: 168, 170: 169, 171: 170, 172: 171, 173: 172, 174: 173, 175: 174, 176: 175, 177: 176, 178: 177, 179: 178, 180: 179, 181: 180, 182: 181, 183: 182, 184: 183, 185: 184, 186: 185, 187: 186, 188: 187, 189: 188, 190: 189, 191: 190, 192: 191, 193: 192, 194: 193, 195: 194, 196: 195, 197: 196, 198: 197, 199: 198, 200: 199, 201: 200, 202: 201, 203: 202, 204: 203, 205: 204, 206: 205, 207: 206, 208: 207, 209: 208, 210: 209, 211: 210, 212: 211, 213: 212, 214: 213, 215: 214, 216: 215, 217: 216, 218: 217, 219: 218, 220: 219, 221: 220, 222: 221, 223: 222, 224: 223, 225: 224, 226: 225, 227: 226, 228: 227, 229: 228, 230: 229, 231: 230, 232: 231, 233: 232, 234: 233, 235: 234, 236: 235, 237: 236, 238: 237, 239: 238, 240: 239, 241: 240, 242: 241, 243: 242, 244: 243, 245: 244, 246: 245, 247: 246, 248: 247, 249: 248, 250: 249, 251: 250, 252: 251, 253: 252, 254: 253, 255: 254, 256: 255, 257: 256, 258: 257, 259: 258, 260: 259, 261: 260, 262: 261, 263: 262, 264: 263, 265: 264, 266: 265, 267: 266, 268: 267, 269: 268, 270: 269, 271: 270, 272: 271, 273: 272, 274: 273, 275: 274, 276: 275, 277: 276, 278: 277, 279: 278, 280: 279, 281: 280, 282: 281, 283: 282, 284: 283, 285: 284, 286: 285, 287: 286, 288: 287, 289: 288, 290: 289, 291: 290, 292: 291, 293: 292, 294: 293, 295: 294, 296: 295, 297: 296, 298: 297, 299: 298, 300: 299, 301: 300, 302: 301, 303: 302, 304: 303, 305: 304, 306: 305, 307: 306, 308: 307, 309: 308, 310: 309, 311: 310, 312: 311, 313: 312, 314: 313, 315: 314, 316: 315, 317: 316, 318: 317, 319: 318, 320: 319, 321: 320, 322: 321, 323: 322, 324: 323, 325: 324, 326: 325, 327: 326, 328: 327, 329: 328, 330: 329, 331: 330, 332: 331, 333: 332, 334: 333, 335: 334, 336: 335, 337: 336, 338: 337, 339: 338, 340: 339, 341: 340, 342: 341, 343: 342, 344: 343, 345: 344, 346: 345, 347: 346, 348: 347, 349: 348, 350: 349, 351: 350, 352: 351, 353: 352, 354: 353, 355: 354, 356: 355, 357: 356, 358: 357, 359: 358, 360: 359, 361: 360, 362: 361, 363: 362, 364: 363, 365: 364, 366: 365, 367: 366, 368: 367, 369: 368, 370: 369, 371: 370, 372: 371, 373: 372, 374: 373, 375: 374, 376: 375, 377: 376, 378: 377, 379: 378, 380: 379, 381: 380, 382: 381, 383: 382, 384: 383, 385: 384, 386: 385, 387: 386, 388: 387, 389: 388, 390: 389, 391: 390, 392: 391, 393: 392, 394: 393, 395: 394, 396: 395, 397: 396, 398: 397, 399: 398, 400: 399, 401: 400, 402: 401, 403: 402, 404: 403, 405: 404, 406: 405, 407: 406, 408: 407, 409: 408, 410: 409, 411: 410, 412: 411, 413: 412, 414: 413, 415: 414, 416: 415, 417: 416, 418: 417, 419: 418, 420: 419, 421: 420, 422: 421, 423: 422, 424: 423, 425: 424, 426: 425, 427: 426, 428: 427, 429: 428, 430: 429, 431: 430, 432: 431, 433: 432, 434: 433, 435: 434, 436: 435, 437: 436, 438: 437, 439: 438, 440: 439, 441: 440, 442: 441, 443: 442, 444: 443, 445: 444, 446: 445, 447: 446, 448: 447, 449: 448, 450: 449, 451: 450, 452: 451, 453: 452, 454: 453, 455: 454, 456: 455, 457: 456, 458: 457, 459: 458, 460: 459, 461: 460, 462: 461, 463: 462, 464: 463, 465: 464, 466: 465, 467: 466, 468: 467, 469: 468, 470: 469, 471: 470, 472: 471, 473: 472, 474: 473, 475: 474, 476: 475, 477: 476, 478: 477, 479: 478, 480: 479, 481: 480, 482: 481, 483: 482, 484: 483, 485: 484, 486: 485, 487: 486, 488: 487, 489: 488, 490: 489, 491: 490, 492: 491, 493: 492, 494: 493, 495: 494, 496: 495, 497: 496, 498: 497, 499: 498, 500: 499, 501: 500, 502: 501, 503: 502, 504: 503, 505: 504, 506: 505, 507: 506, 508: 507, 509: 508, 510: 509, 511: 510, 512: 511, 513: 512, 514: 513, 515: 514, 516: 515, 517: 516, 518: 517, 519: 518, 520: 519, 521: 520, 522: 521, 523: 522, 524: 523, 525: 524, 526: 525, 527: 526, 528: 527, 529: 528, 530: 529, 531: 530, 532: 531, 533: 532, 534: 533, 535: 534, 536: 535, 537: 536, 538: 537, 539: 538, 540: 539, 541: 540, 542: 541, 543: 542, 544: 543, 545: 544, 546: 545, 547: 546, 548: 547, 549: 548, 550: 549, 551: 550, 552: 551, 553: 552, 554: 553, 555: 554, 556: 555, 557: 556, 558: 557, 559: 558, 560: 559, 561: 560, 562: 561, 563: 562, 564: 563, 565: 564, 566: 565, 567: 566, 568: 567, 569: 568, 570: 569, 571: 570, 572: 571, 573: 572, 574: 573, 575: 574, 576: 575, 577: 576, 578: 577, 579: 578, 580: 579, 581: 580, 582: 581, 583: 582, 584: 583, 585: 584, 586: 585, 587: 586, 588: 587, 589: 588, 590: 589, 591: 590, 592: 591, 593: 592, 594: 593, 595: 594, 596: 595, 597: 596, 598: 597, 599: 598, 600: 599, 601: 600, 602: 601, 603: 602, 604: 603, 605: 604, 606: 605, 607: 606, 608: 607, 609: 608, 610: 609}\n"
     ]
    }
   ],
   "source": [
    "_, user_mapping = load_node_csv(ratings_path, index_col='userId')\n",
    "print(user_mapping)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e771a5a",
   "metadata": {},
   "source": [
    "这样，我们就可以初始化HeteroData对象，并将两种节点类型传递给它："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8ad83005",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HeteroData(\n",
      "  \u001b[1muser\u001b[0m={ num_nodes=610 },\n",
      "  \u001b[1mmovie\u001b[0m={ x=[9742, 404] }\n",
      ")\n",
      "torch.Size([9742, 404])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.data import HeteroData\n",
    "\n",
    "data = HeteroData()\n",
    "\n",
    "data['user'].num_nodes = len(user_mapping)  # \"user\"没有任何特征\n",
    "data['movie'].x = movie_x\n",
    "\n",
    "print(data)\n",
    "print(movie_x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a788cee4",
   "metadata": {},
   "source": [
    "由于用户没有任何节点级别的信息，我们只定义其节点数。因此，在异构图模型的训练过程中，我们可能需要通过torch.nn.Embedding以端到端的方式学习不同的用户嵌入。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9737466",
   "metadata": {},
   "source": [
    "接下来，我们来看看根据用户的评分将他们与电影联系起来。为此，我们定义了一个方法load_edge_csv()，该方法从ratings.csv返回shape[2，num_ratings]的最终edge_index表示，以及原始*.csv文件中存在的任何其他功能："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "be64f930",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,\n",
    "                  encoders=None, **kwargs):\n",
    "    df = pd.read_csv(path, **kwargs)\n",
    "\n",
    "    src = [src_mapping[index] for index in df[src_index_col]]\n",
    "    dst = [dst_mapping[index] for index in df[dst_index_col]]\n",
    "    #print(len(src))\n",
    "    #print(len(dst))\n",
    "    edge_index = torch.tensor([src, dst])\n",
    "\n",
    "    edge_attr = None\n",
    "    if encoders is not None:\n",
    "        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]\n",
    "        edge_attr = torch.cat(edge_attrs, dim=-1)\n",
    "        \n",
    "    #print(edge_attr.shape)\n",
    "\n",
    "    return edge_index, edge_attr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "129c3a3a",
   "metadata": {},
   "source": [
    "这里，src_index_col和dst_index_col分别定义源节点和目标节点的索引列。我们进一步利用节点级映射src_mapping和dst_mapping来确保原始索引在我们的最终表示中被映射到正确的连续索引。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7e5270c",
   "metadata": {},
   "source": [
    "对于文件中定义的每条边，它会在src_mapping和dst_mapping中查找正向索引，并适当地移动数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c860f17",
   "metadata": {},
   "source": [
    "与load_node_csv()类似，编码器用于返回额外的边特征信息。例如，为了从ratings.csv中的rating列加载ratings，我们可以定义一个IdentityEncoder，它只需将浮点值列表转换为PyTorch张量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76aceef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class IdentityEncoder:\n",
    "    def __init__(self, dtype=None):\n",
    "        self.dtype = dtype\n",
    "\n",
    "    def __call__(self, df):\n",
    "        return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d14c381",
   "metadata": {},
   "source": [
    "这样，我们就可以完成我们的HeteroData对象了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ff9efb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HeteroData(\n",
      "  \u001b[1muser\u001b[0m={ num_nodes=610 },\n",
      "  \u001b[1mmovie\u001b[0m={ x=[9742, 404] },\n",
      "  \u001b[1m(user, rates, movie)\u001b[0m={\n",
      "    edge_index=[2, 100836],\n",
      "    edge_label=[100836, 1]\n",
      "  }\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "edge_index, edge_label = load_edge_csv(\n",
    "    ratings_path,\n",
    "    src_index_col='userId',\n",
    "    src_mapping=user_mapping,\n",
    "    dst_index_col='movieId',\n",
    "    dst_mapping=movie_mapping,\n",
    "    encoders={'rating': IdentityEncoder(dtype=torch.long)},\n",
    ")\n",
    "\n",
    "data['user', 'rates', 'movie'].edge_index = edge_index\n",
    "data['user', 'rates', 'movie'].edge_label = edge_label\n",
    "\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4092421f",
   "metadata": {},
   "source": [
    "该HeteroData对象是PyG中异构图的原生格式，可以用作异构图模型的输入。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep_learning",
   "language": "python",
   "name": "deep_learning"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
